{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4ro48hGpifFm"
      },
      "source": [
        "Copyright 2020 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "     https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SCbRWUlKimvH"
      },
      "source": [
        "JAX implementation of Gaussian process classification (GPC) using parallelized elliptical slice sampling (ESS). The algorithm is taken from Iain Murray, Ryan Prescott Adams, and David JC MacKay. \"[Elliptical Slice Sampling.](http://proceedings.mlr.press/v9/murray10a.html)\" (2010). \n",
        "\n",
        "We leverage recent theoretical advances that characterize the function-space prior of an ensemble of infinitely-wide NNs as a Gaussian process, termed the neural network Gaussian process (NNGP). We use the NNGP with a softmax link function to build a probabilistic model for multi-class classification and marginalize over the latent Gaussian outputs to sample from the posterior using ESS. This gives us a better understanding of the implicit prior NNs place on function space and allows a direct comparison of the calibration of the NNGP and its finite-width analogue. See Adlam *et al.* \"[Exploring the Uncertainty Properties of Neural Networks' Implicit Priors in the Infinite-Width Limit.](https://arxiv.org/abs/2010.07355)\" (2020)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V5zZLyJdjKBj"
      },
      "source": [
        "##Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kR7ChISWh6aO"
      },
      "outputs": [],
      "source": [
        "import time\n",
        "import numpy as onp\n",
        "import tensorflow_datasets as tfds\n",
        "import jax\n",
        "from jax import core\n",
        "from jax import device_put\n",
        "from jax import devices\n",
        "from jax import jit\n",
        "from jax import lax\n",
        "from jax import numpy as np\n",
        "from jax import pmap\n",
        "from jax import random\n",
        "from jax import vmap\n",
        "from jax import config as jax_config\n",
        "from jax.nn import softmax\n",
        "from jax.nn import log_softmax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "executionInfo": {
          "elapsed": 311,
          "status": "ok",
          "timestamp": 1602775863746,
          "user": {
            "displayName": "Ben Adlam",
            "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiwYcB0q-q5L_0TwtSgL_nnb0WSEspvc_2fCHej=s64",
            "userId": "10603308850729998596"
          },
          "user_tz": 240
        },
        "id": "Ed6auXKyqCFt",
        "outputId": "daef333e-3605-4411-c535-ee14690fcfff"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "/bin/sh: pip: command not found\n"
          ]
        }
      ],
      "source": [
        "!pip install -q git+https://www.github.com/google/neural-tangents"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WpZWwV6uqEFp"
      },
      "outputs": [],
      "source": [
        "import neural_tangents as nt\n",
        "from neural_tangents import stax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RHDBdSGeiDJx"
      },
      "outputs": [],
      "source": [
        "# Use float64 when possible to safeguard against numerical issues.\n",
        "jax_config.update('jax_enable_x64', True)\n",
        "\n",
        "# Hardcode the maximum number of ESS steps allowed.\n",
        "# To avoid getting stuck in the while loop due to numerical issues.\n",
        "MAX_STEPS = 1e4"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zi1KDa6kjNOq"
      },
      "source": [
        "##Function Definitions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xTMbweXFiFKH"
      },
      "outputs": [],
      "source": [
        "def make_sda(arrays, devices=None):\n",
        "  \"\"\"Manually make a ShardedDeviceArray from a list of arrays.\"\"\"\n",
        "  if devices is None:\n",
        "    devices = jax.local_devices()\n",
        "  buffers = []\n",
        "  for arr, dev in zip(arrays, devices):\n",
        "    buffers.append(jax.interpreters.xla.device_put(arr, device=dev))\n",
        "    x_shape, x_dtype = arr.shape, arr.dtype\n",
        "  aval = core.ShapedArray((len(devices),) + x_shape, x_dtype)\n",
        "  return jax.pxla.ShardedDeviceArray(aval, buffers)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F58qUrGPiHFV"
      },
      "outputs": [],
      "source": [
        "def es_step(key, f_old, prior_sampler, log_l):\n",
        "  \"\"\"Performs a single step of elliptical slice sampling.\n",
        "\n",
        "  Args:\n",
        "    key: A JAX PRNGKey.\n",
        "    f_old: The current state of the Markov chain stored as a 1D DeviceArray of\n",
        "      shape (dim,).\n",
        "    prior_sampler: We assume the prior distribution is a mean zero multivariate\n",
        "      Gaussian of dimension dim that can be parameterized by the Cholesky\n",
        "      decomposition its covariance matrix, denoted l. Then prior_sampler is a\n",
        "      function to sample from the prior that takes a PRNGKey and l as arguments\n",
        "      and returns a 1D DeviceArray of shape (dim,).\n",
        "    log_l: A function that returns the log-likelihood of the current state.\n",
        "  Returns:\n",
        "    A tuple (key, f_new, success, i), where key is a new PRNGKey, f_new is the\n",
        "    new state of the Markov chain, success is a bool, and i is an int. The bool\n",
        "    success indicates where the step of ESS was performed successfully. In some\n",
        "    cases the step can fail due to numerical precision; for some samples of the\n",
        "    randomness, the only acceptable transition for the chain is very close to\n",
        "    its current state. The chain can take many loops to complete the step, and\n",
        "    eventually the new state is equal to the current state (up to numerical\n",
        "    precision). So success is an indicator for when this failure occurs. Note as\n",
        "    long as it happens infrequently, the overall sampling is fine. Finally, i\n",
        "    indicates the number of iterations taken in the while loop.\n",
        "  \"\"\"\n",
        "  key, subkey = random.split(key)\n",
        "  nu = prior_sampler(subkey)\n",
        "  key, subkey = random.split(key)\n",
        "  log_y = log_l(f_old) + np.log(random.uniform(subkey))\n",
        "  key, subkey = random.split(key)\n",
        "  theta = 2 * np.pi * random.uniform(subkey)\n",
        "  theta_min, theta_max = theta - 2 * np.pi, theta\n",
        "\n",
        "  def _cond(vals):\n",
        "    _, f_new, _, _, _, i = vals\n",
        "    return np.logical_and(log_l(f_new) \u003c log_y, i \u003c= MAX_STEPS)\n",
        "\n",
        "  def _body(vals):\n",
        "    \"\"\"Body function for while loop to shrink the feasible region of theta.\"\"\"\n",
        "    key, f_new, theta, theta_min, theta_max, i = vals\n",
        "    i_new = i + 1\n",
        "    # if theta \u003c 0, then theta_min = theta\n",
        "    theta_min += (theta - theta_min) * (np.sign(-theta) + 1.) / 2.\n",
        "    # else theta_max = theta\n",
        "    theta_max += (theta - theta_max) * (np.sign(theta) + 1.) / 2.\n",
        "    key, subkey = random.split(key)\n",
        "    theta = theta_min + (theta_max - theta_min) * random.uniform(subkey)\n",
        "    f_new = f_old * np.cos(theta) + nu * np.sin(theta)\n",
        "\n",
        "    return key, f_new, theta, theta_min, theta_max, i_new\n",
        "\n",
        "  f_new = f_old * np.cos(theta) + nu * np.sin(theta)\n",
        "  key, f_new, theta, theta_min, theta_max, i = lax.while_loop(\n",
        "      _cond, _body, (key, f_new, theta, theta_min, theta_max, 0))\n",
        "\n",
        "  return (key, np.where(i \u003c= MAX_STEPS, f_new, f_old))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A8n_zRZmh9e_"
      },
      "outputs": [],
      "source": [
        "def es_sample(key, dim, nc, prior_sampler, log_l, num_samples, burn_in,\n",
        "              trace_tuple=None, eval_tuple=None, logging_fn=print,\n",
        "              init_state=None):\n",
        "  \"\"\"Sample from the posterior using MCMC given by elliptical slice sampling.\n",
        "\n",
        "  Args:\n",
        "    key: A JAX PRNGKey.\n",
        "    dim: An int specifying the dimension of the state, since it is 1D.\n",
        "    nc: An int specifying the number of classes in the classification.\n",
        "    prior_sampler: We assume the prior distribution is a mean zero multivariate\n",
        "      Gaussian of dimension dim that can be parameterized by the Cholesky\n",
        "      decomposition its covariance matrix, denoted l. Then prior_sampler is a\n",
        "      function to sample from the prior that takes a PRNGKey and l as arguments\n",
        "      and returns a 1D DeviceArray of shape (dim,).\n",
        "    log_l: A function that returns the log-likelihood of the current state.\n",
        "    num_samples: An int for the number of samples from or steps of MCMC.\n",
        "    burn_in: An int specifying the number of steps to throw out from the MCMC.\n",
        "    trace_tuple: A tuple (fn, i), where fn is a function to apply to the state\n",
        "      of the Markov chain and save as a DeviceArray, and i is an int specifying\n",
        "      the interval at which to save the trace. Note saving lots of data\n",
        "      frequently can cause OOMs.\n",
        "    eval_tuple: A tuple (fn, i), where fn is a function to apply to the current\n",
        "      posterior given by the Markov chain and log the result, and i is an int\n",
        "      specifying the interval at which to evaluate.\n",
        "    logging_fn: Defaults to logging.info, but if using the code in Colab print()\n",
        "      can be used.\n",
        "    init_state: A 1D DeviceArray of shape (dim,) that is the initial state for\n",
        "      the ESS.\n",
        "  Returns:\n",
        "    A tuple (p, eval_traces, step_norms), where p is a DeviceArray for the\n",
        "    posterior, eval_traces is a DeviceArray containing statisitics given by\n",
        "    eval_fn from the trace, and step_norms contains the steps sizes of Markov\n",
        "    chain's transitions.\n",
        "  \"\"\"\n",
        "  start_time = time.time()\n",
        "  loop_time = start_time\n",
        "  logging_fn('Starting MCMC sampling...\\n')\n",
        "\n",
        "  p_num = jax.local_device_count()\n",
        "  key = np.reshape(\n",
        "      random.split(key, jax.device_count()),\n",
        "      [jax.host_count(), jax.local_device_count(), 2])[jax.host_id()]\n",
        "  total_samples = int(num_samples // p_num + burn_in)\n",
        "  if init_state is None:\n",
        "    sample = pmap(lambda x: np.zeros([dim]))(np.arange(p_num))\n",
        "  else:\n",
        "    sample = pmap(lambda x: init_state)(np.arange(p_num))\n",
        "  p = pmap(lambda x: np.zeros([dim//nc, nc]))(np.arange(p_num))\n",
        "  def accumulate_softmax(x, y, t):\n",
        "    out = ((t-1.)/t) * x + (1./t) * softmax(np.reshape(y, [dim//nc, nc]))\n",
        "    # Guess current state while still in burn in phase.\n",
        "    out = np.where(t \u003e 0, out, softmax(np.reshape(y, [dim//nc, nc])))\n",
        "    return out / np.sum(out, axis=-1, keepdims=True)\n",
        "\n",
        "  # Set up function to evaluate the current p(y|x).\n",
        "  eval_print_steps = total_samples\n",
        "  if eval_tuple is not None:\n",
        "    eval_print, eval_print_steps = eval_tuple\n",
        "\n",
        "  # Set up function to apply to current sample and save.\n",
        "  eval_traces = []\n",
        "  trace_fn_steps = total_samples\n",
        "  if trace_tuple is not None:\n",
        "    trace_fn, trace_fn_steps = trace_tuple\n",
        "    trace_fn = pmap(trace_fn)\n",
        "\n",
        "  epoch = int(min(eval_print_steps, trace_fn_steps))\n",
        "\n",
        "  def body_fun(i, vals):\n",
        "    key, sample, p = vals\n",
        "    key, sample = es_step(key, sample, prior_sampler, log_l)\n",
        "    p = accumulate_softmax(p, sample, i - burn_in + 1.)\n",
        "    return key, sample, p\n",
        "\n",
        "  @pmap\n",
        "  def for_loop_fn(i, key, sample, p):\n",
        "    return lax.fori_loop(i, i + epoch, body_fun, (key, sample, p))\n",
        "\n",
        "  for i in range(0, total_samples, epoch):\n",
        "    key, sample, p = for_loop_fn(i * np.ones([p_num]), key, sample, p)\n",
        "\n",
        "    i += epoch\n",
        "    if eval_tuple is not None:\n",
        "      eval_print(p, i)\n",
        "\n",
        "    if trace_tuple is not None:\n",
        "      eval_traces.append(trace_fn(sample))\n",
        "\n",
        "    logging_fn('Completed step {}/{} in {:.3f} mins at {}\\n'.format(\n",
        "        i, total_samples, (time.time() - loop_time) / 60.,\n",
        "        time.asctime(time.localtime())))\n",
        "    loop_time = time.time()\n",
        "\n",
        "  p = np.mean(p, axis=0)\n",
        "\n",
        "  logging_fn('Sampling complete in {:.3f} mins.'.format(\n",
        "      (time.time() - start_time) / 60.))\n",
        "\n",
        "  return p, np.array(eval_traces)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ndM2Rx1Oh9hn"
      },
      "outputs": [],
      "source": [
        "def gpc_predict(\n",
        "    k_fn, k_scale, x0, y0, x1, key, num_samples, burn_in, diag_reg=0.,\n",
        "    ess_dtype=np.float32, trace_tuple=None, eval_tuple=None,\n",
        "    logging_fn=print):\n",
        "  \"\"\"Approximates test set posterior given a kernel and training data using ESS.\n",
        "\n",
        "  Args:\n",
        "    k_fn: A neural_tangents kernel function. Can be None if the kernel's\n",
        "      Cholesky decomposition will be loaded from CNS.\n",
        "    k_scale: A scalar to rescale the kernel matrix k. Note this is only used if\n",
        "      the kernel is computed using the kernel_fn, not if it is loaded from CNS.\n",
        "    x0: An array of training points.\n",
        "    y0: An array of labels for the training data.\n",
        "    x1: An array of test points.\n",
        "    key: A PRNGKey.\n",
        "    num_samples: A positive integer specifying the number of steps of MCMC.\n",
        "    burn_in: A positive integer specifying to burn-in for the MCMC sampling.\n",
        "    diag_reg: A nonnegative float to add to the diagonal and regularize the\n",
        "      kernel matrix.\n",
        "    ess_dtype: The dtype for the computed Cholesky and the state of the ess.\n",
        "    trace_tuple: A tuple (fn, num), where fn is a function applied to the\n",
        "      current state and num is an integer specifying how often the output is\n",
        "      saved.\n",
        "    eval_tuple: A tuple (fn, num), where fn is a function applied to the current\n",
        "      probabilities to log their performance every num steps.\n",
        "    logging_fn: Function to log progress.\n",
        "\n",
        "  Returns:\n",
        "    Two arrays containing the posterior on the training set and the test set.\n",
        "    Consequently they have shapes (n0, nc) and (n1, nc), where each row is a\n",
        "    distribution.\n",
        "  \"\"\"\n",
        "  start_time = time.time()\n",
        "  logging_fn('========= jax.device_count(): %s' % jax.device_count())\n",
        "  logging_fn('========= local_device_count: %s' % jax.local_device_count())\n",
        "  logging_fn('========= devices: %s' % ',\"'.join(map(str, jax.devices())))\n",
        "  logging_fn('========= host_id: %s' % jax.host_id())\n",
        "\n",
        "  if k_fn is None and cns_l is None:\n",
        "    raise ValueError('Either k_fn or cns_l must be specified!')\n",
        "\n",
        "  n0, nc = y0.shape\n",
        "  n1 = x1.shape[0]\n",
        "\n",
        "  init_state = None\n",
        "\n",
        "\n",
        "  logging_fn('Computing kernel matrices...')\n",
        "  k = k_fn(np.vstack([x0, x1]), None, 'nngp')\n",
        "  logging_fn('Computed kernel matrices in {:.3f} mins.'.format(\n",
        "      (time.time() - start_time) / 60.))\n",
        "\n",
        "  # Computing initial state.\n",
        "  logging_fn('Computing initial state for ESS...')\n",
        "  y1_hat = k[:n0, :n0] @ np.linalg.solve(\n",
        "      k[:n0, :n0] + diag_reg * np.eye(n0), y0)\n",
        "  init_state = np.reshape(np.vstack([y0, y1_hat]), [-1])\n",
        "  logging_fn('Computing initial state for ESS in {:.3f} mins.'.format(\n",
        "      (time.time() - start_time) / 60.))\n",
        "\n",
        "  # Compute the Cholesky once, adding a diagonal regularizer for stability.\n",
        "  start_time = time.time()\n",
        "  logging_fn('Computing Cholesky decomposition...')\n",
        "  l = onp.sqrt(k_scale) * onp.linalg.cholesky(\n",
        "      k + diag_reg * np.trace(k) / (n0 + n1) * np.eye(n0 + n1))\n",
        "  # ANY CASTING SHOULD HAPPEN HERE, AFTER THE CHOLESKY COMPUTATION.\n",
        "  l = jax.device_put(l, jax.devices('cpu')[0])\n",
        "  l = np.array(l, dtype=ess_dtype)  # Cast Cholesky to desired precision.\n",
        "  logging_fn('Computed Cholesky decomposition in {:.3f} mins.'.format(\n",
        "      (time.time() - start_time) / 60.))\n",
        "\n",
        "  def prior_sampler(key):\n",
        "    normal_samples = random.normal(key, [n0+n1, nc])\n",
        "    normal_samples = np.array(normal_samples, dtype=ess_dtype)\n",
        "    # The Cholesky of the Kronecker is the Kronecker of the Choleskys, i.e.\n",
        "    # np.kron(l, np.eye(nc)). Avoid instantiating large matrix.\n",
        "    # NB. While this is correct, it seems buggy on TPU, i.e. there are different\n",
        "    # answers for the code below and actually using the Kronecker product.\n",
        "    return np.reshape(l @ normal_samples, [-1])\n",
        "\n",
        "  def log_l(f):\n",
        "    \"\"\"Log-likelihood: p(data|f_train, f_test) == p(data|f_train).\"\"\"\n",
        "    f_ = np.reshape(f[:n0*nc], [n0, nc])\n",
        "    p_ = log_softmax(f_)\n",
        "    # Assumes y0 contains one-hot labels.\n",
        "    return np.sum(p_ * y0)\n",
        "\n",
        "  p, eval_trace = es_sample(\n",
        "      key, nc*(n0+n1), nc, prior_sampler, log_l, num_samples, burn_in,\n",
        "      trace_tuple, eval_tuple, logging_fn=logging_fn, init_state=init_state)\n",
        "  return p[:n0], p[n0:], eval_trace"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B_CatUmomBV6"
      },
      "outputs": [],
      "source": [
        "DATASETS = {\n",
        "    'mnist': 10,\n",
        "    'cifar10': 10,\n",
        "    'cifar100': 100,\n",
        "}\n",
        "\n",
        "\n",
        "def to_one_hot(label, nc):\n",
        "  return (np.arange(nc) == label).astype(int)\n",
        "\n",
        "\n",
        "def get_dataset(train_dataset, test_dataset, n0, n1, classes, valid_set=False,\n",
        "                flatten=True):\n",
        "  \"\"\"Function to load data and preprocess it.\n",
        "  \n",
        "  Note this implementation is slow due to filtering and preprocessing.\"\"\"\n",
        "  if train_dataset not in DATASETS or test_dataset not in DATASETS:\n",
        "    raise ValueError('Dataset \"{}\" or \"{}\" not recognized! Choose from: '\n",
        "                     '{}'.format(train_dataset, test_dataset, DATASETS))\n",
        "  ds_train = tfds.load(name=train_dataset)\n",
        "  ds_test = tfds.load(name=test_dataset)\n",
        "\n",
        "  if classes is None:\n",
        "    nc = DATASETS[train_dataset]\n",
        "    filter_fn = lambda x: True\n",
        "  elif isinstance(classes, int):\n",
        "    if classes \u003e DATASETS[train_dataset]:\n",
        "      raise ValueError('Requesting {} classes from {}, but it only has {}'\n",
        "                       'classes!'.format(\n",
        "                           classes, train_dataset, DATASETS[train_dataset]))\n",
        "    nc = classes\n",
        "    classes = set(range(nc))\n",
        "    filter_fn = lambda x: x['label'] in classes\n",
        "  elif isinstance(classes, set) or isinstance(classes, list):\n",
        "    classes = set(classes)  # Remove any duplicate classes.\n",
        "    nc = len(classes)\n",
        "    filter_fn = lambda x: x['label'] in classes\n",
        "  else:\n",
        "    raise ValueError(\n",
        "        '\"classes\" must be type None, int, or set! Given {}'.format(\n",
        "            type(classes)))\n",
        "\n",
        "  ds_train_np = tfds.as_numpy(ds_train)\n",
        "  ds_test_np = tfds.as_numpy(ds_test)\n",
        "  # Keep datapoints as regular NumPy arrays on CPU for Cholesky computation.\n",
        "  # NB. Currently, requesting float64 does not work.\n",
        "  x0_ = onp.array([np.array(x['image']).astype(onp.float64)\n",
        "                   for x in ds_train_np['train'] if filter_fn(x)])\n",
        "  x1_ = onp.array([np.array(x['image']).astype(onp.float64)\n",
        "                   for x in ds_test_np['test'] if filter_fn(x)])\n",
        "\n",
        "  ds_train_np = tfds.as_numpy(ds_train)\n",
        "  ds_test_np = tfds.as_numpy(ds_test)\n",
        "  y0_ = np.array([to_one_hot(x['label'], nc)\n",
        "                  for x in ds_train_np['train'] if filter_fn(x)])\n",
        "  y1_ = np.array([to_one_hot(x['label'], nc)\n",
        "                  for x in ds_test_np['test'] if filter_fn(x)])\n",
        "  y0 = y0_[:n0]\n",
        "  y1 = y1_[:n1]\n",
        "\n",
        "  if valid_set:\n",
        "    if x0_.shape[0] \u003c n0 + n1:\n",
        "      raise ValueError('Validation set is taken from end of training split. '\n",
        "                       'So n0+n1 cannot exceed total training points in '\n",
        "                       'requested dataset, but received {} and {}'.format(\n",
        "                           n0+n1, x0_.shape[0]))\n",
        "    x1_ = x0_[-n1:]\n",
        "    y1 = y0_[-n1:]\n",
        "\n",
        "  mean = onp.mean(x0_[:n0])\n",
        "  std = onp.std(x0_[:n0])\n",
        "  if flatten:\n",
        "    x0_ = onp.reshape(x0_[:n0], (n0, -1))\n",
        "    x1_ = onp.reshape(x1_[:n1], (n1, -1))\n",
        "  else:\n",
        "    x0_ = x0_[:n0]\n",
        "    x1_ = x1_[:n1]\n",
        "  x0 = (x0_ - mean) / std\n",
        "  x1 = (x1_ - mean) / std\n",
        "  # NOTE: CURRENTLY THIS CASTS THE ARRAYS TO FLOAT32!\n",
        "  x0 = device_put(x0, devices('cpu')[0])\n",
        "  x1 = device_put(x1, devices('cpu')[0])\n",
        "\n",
        "  return (x0, y0), (x1, y1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jsWWSCRUm7Za"
      },
      "source": [
        "##Define Parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c0LeFdT_m-di"
      },
      "outputs": [],
      "source": [
        "n0 = 1000\n",
        "n1 = 1000\n",
        "nc = 10\n",
        "kernel_batch_size = None\n",
        "train_dataset = 'cifar10'\n",
        "test_dataset = 'cifar10'\n",
        "valid_set = False\n",
        "kernel_type = 'fc'\n",
        "activation = 'erf'\n",
        "W_std = 1.4142135624\n",
        "b_std = 0.\n",
        "k_scale = 1.\n",
        "k_depth = 5\n",
        "diag_reg = 1e-6\n",
        "ess_dtype = np.float32\n",
        "key = random.PRNGKey(0)\n",
        "mcmc_steps = 1e5\n",
        "eval_steps = 1e4\n",
        "burn_in = 1e4\n",
        "iterations = 1\n",
        "save_trace = False"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fP1-7bNrjVNu"
      },
      "source": [
        "##Load Data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HDeDppV5jWwA"
      },
      "outputs": [],
      "source": [
        "flatten = False if kernel_type == 'cnn' else True\n",
        "(x0, y0), (x1, y1) = get_dataset(train_dataset, test_dataset, n0, n1, nc,\n",
        "                                 valid_set=valid_set, flatten=flatten)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Vid0VmiTje50"
      },
      "source": [
        "##Define Kernel"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DTKO-oeYjhLu"
      },
      "outputs": [],
      "source": [
        "if activation == 'relu':\n",
        "  act = stax.Relu()\n",
        "elif activation == 'erf':\n",
        "  act = stax.Erf()\n",
        "\n",
        "if kernel_type == 'fc':\n",
        "  collect_layers = [stax.Dense(512, W_std=W_std, b_std=b_std), act]*k_depth\n",
        "  collect_layers += [stax.Dense(1, W_std=W_std, b_std=b_std)]\n",
        "  _, _, k_fn = stax.serial(*collect_layers)\n",
        "\n",
        "elif kernel_type == 'cnn':\n",
        "  conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std,\n",
        "                            padding='SAME', parameterization='ntk')\n",
        "  collect_layers = [conv(512, (3, 3)), act]*k_depth\n",
        "  collect_layers += [stax.Flatten(),\n",
        "                      stax.Dense(1, W_std, b_std, parameterization='ntk')]\n",
        "  _, _, k_fn = stax.serial(*collect_layers)\n",
        "\n",
        "else:\n",
        "  raise ValueError('Kernel type {} not recognized! Choose either fc or '\n",
        "                    'cnn.'.format(kernel_type))\n",
        "  \n",
        "if kernel_batch_size is not None:\n",
        "  # Recommended batch size ~25 for pooling, ~800 for flattening\n",
        "  if (n0+n1) % kernel_batch_size * local_device_count() != 0:\n",
        "    raise ValueError('Device count times batch size must divide the training '\n",
        "                      'set plus test set size! Received {} and {}.'.format(\n",
        "                          kernel_batch_size*local_device_count(), n0+n1))\n",
        "  k_fn = nt.batch(k_fn, batch_size=kernel_batch_size, store_on_device=False)\n",
        "\n",
        "else:\n",
        "  k_fn = jit(k_fn, static_argnums=(1, 2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qkMFByzcjRxg"
      },
      "source": [
        "##Run ESS"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Uz2DAfEFksSu"
      },
      "outputs": [],
      "source": [
        "@vmap\n",
        "def acc_(y, y_hat):\n",
        "  return np.argmax(y) == np.argmax(y_hat)\n",
        "acc = lambda x, y: np.mean(acc_(x, y))\n",
        "\n",
        "def eval_print(p, t):\n",
        "  p = onp.array(p, dtype=onp.float32)\n",
        "  # Probabilities are 0 during burn in.\n",
        "  if np.max(p) \u003c 1. / nc:\n",
        "    return None\n",
        "  p = p / onp.sum(p, axis=-1, keepdims=True)\n",
        "\n",
        "  if p.ndim == 3:\n",
        "    for i, p_ in enumerate(p):\n",
        "      p0, p1 = p_[:n0], p_[n0:]\n",
        "      print(\"Train acc for chain {}: {}\".format(i, acc(y0, p0)))\n",
        "      print(\"Test acc for chain {}: {}\".format(i, acc(y1, p1)))\n",
        "\n",
        "    p = np.mean(p, axis=0)\n",
        "  p0, p1 = p[:n0], p[n0:]\n",
        "  print(\"Train acc: {}\".format(acc(y0, p0)))\n",
        "  print(\"Test acc: {}\".format(acc(y1, p1)))\n",
        "\n",
        "eval_tuple = (eval_print, eval_steps)\n",
        "\n",
        "trace_tuple = None\n",
        "if save_trace:\n",
        "  # Save trace of subset of data.\n",
        "  def trace_fn(sample):\n",
        "    x = np.reshape(sample, (-1, nc))\n",
        "    return np.vstack([x[:5], x[n0: n0+5]])\n",
        "\n",
        "  trace_tuple = (trace_fn, 1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uXwEXWBKivg4"
      },
      "outputs": [],
      "source": [
        "total_p0, total_p1 = 0., 0.\n",
        "for i in range(iterations):\n",
        "  key, subkey = random.split(key)\n",
        "  p0, p1, eval_trace = gpc_predict(\n",
        "      k_fn, k_scale, x0, y0, x1, subkey, num_samples=mcmc_steps,\n",
        "      burn_in=burn_in, diag_reg=diag_reg,\n",
        "      ess_dtype=ess_dtype, trace_tuple=trace_tuple, eval_tuple=eval_tuple)\n",
        "  total_p0 = (i * total_p0 + p0) / (i+1)\n",
        "  total_p1 = (i * total_p1 + p1) / (i+1)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "V5zZLyJdjKBj",
        "zi1KDa6kjNOq",
        "jsWWSCRUm7Za",
        "fP1-7bNrjVNu",
        "Vid0VmiTje50",
        "qkMFByzcjRxg"
      ],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "ESS for GPC with the NNGP.ipynb",
      "provenance": [
        {
          "file_id": "1Fzx71NEiXmAwlCmvKyMvCkjpNl0x-rMk",
          "timestamp": 1603475071229
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
